import pandas as pd
import numpy as np
import sys

# Load the research_audits residency/relationship test indicators.
# Background on building full_research_audits_df_m and the strategy for adding
# 2014 data is in construct_res_rel_research_audits_ls.py.
res_rel = (
    pd.read_csv('/REDACTED/res_rel_ls.csv')  # pulled from cdw
    .rename(columns={'taxpayer_id': 'taxpayer_id_new'})
    .drop(columns=['Unnamed: 0'])
)

# Remove taxpayers for whom every relationship-test and residency-test
# indicator (children 1-3) is missing.
test_indicator_cols = [
    'eitc_ch_elg_rel_test_ch1_ind',
    'eitc_ch_elg_rel_test_ch2_ind',
    'eitc_ch_elg_rel_test_ch3_ind',
    'eitc_restest_met_ch1_ind',
    'eitc_restest_met_ch2_ind',
    'eitc_restest_met_ch3_ind',
]
all_indicators_missing = res_rel[test_indicator_cols].isnull().all(axis=1)
res_rel = res_rel[~all_indicators_missing]
# (author's note) down to 199 left from 2014

# Remove rows that have no taxpayer identifier.
has_taxpayer_id = res_rel['taxpayer_id_new'].notnull()
res_rel = res_rel[has_taxpayer_id]
# (author's note) all of these are from 2014, so down to 14 remaining

# Determine how to de-duplicate: collect every taxpayer_id_new that occurs
# more than once, then verify that the duplicated rows are identical.
taxpayer_ids = res_rel['taxpayer_id_new']
dupes = res_rel[taxpayer_ids.isin(taxpayer_ids[taxpayer_ids.duplicated()])].sort_values('taxpayer_id_new')
dupe_taxpayer_ids = dupes.taxpayer_id_new.unique().tolist()

# uniform_dfs[i] is True when all rows sharing dupe_taxpayer_ids[i] agree in
# EVERY column. The per-id result is built from all columns at once; resetting
# an accumulator inside the column loop would only ever test the last column.
uniform_dfs = []
for taxpayer_id in dupe_taxpayer_ids:
    dupe_rows = res_rel[res_rel.taxpayer_id_new == taxpayer_id]
    # Series.unique() keeps NaN as a value, so an all-NaN column counts as uniform.
    uniform_dfs.append(
        all(len(dupe_rows[col].unique()) == 1 for col in dupe_rows.columns)
    )

# Dropping duplicates keeps the first row per id, which is only harmless when
# the duplicated rows are identical. Surface any id where that does not hold
# instead of silently discarding the result of the check.
non_uniform_ids = [
    tid for tid, is_uniform in zip(dupe_taxpayer_ids, uniform_dfs) if not is_uniform
]
if non_uniform_ids:
    print(
        "WARNING: duplicated taxpayer_id_new rows differ in at least one column "
        "for " + str(len(non_uniform_ids)) + " id(s); drop_duplicates keeps the "
        "first row for each: " + str(non_uniform_ids),
        file=sys.stderr,
    )

res_rel = res_rel.drop_duplicates('taxpayer_id_new')
# (author's note) down to 13 from 2014

# Remove taxpayers that have no relationship-test info, i.e. the indicator is 2
# for all three child slots (presumably 2 encodes "no info / no such child" --
# confirm against the cdw data dictionary).
# The condition must test ch1, ch2 AND ch3; testing ch1 three times would drop
# every row with ch1 == 2 regardless of the other two children.
# NOTE(review): row counts recorded elsewhere in this script (e.g. "down to 2
# from 2014") were observed with the ch1-only condition -- re-check them.
res_rel = res_rel[~((res_rel.eitc_ch_elg_rel_test_ch1_ind == 2) & (res_rel.eitc_ch_elg_rel_test_ch2_ind == 2) & (res_rel.eitc_ch_elg_rel_test_ch3_ind == 2))]

# Determine number of kids from the relationship-test indicators, reading a
# value of 2 as "slot not used" (assumption -- confirm the coding):
#   ch2 == 2               -> 1 kid
#   ch2 != 2 and ch3 == 2  -> 2 kids
#   ch2 != 2 and ch3 != 2  -> 3 kids
# The three helper columns must be mutually exclusive because they are summed.
# Without the ch2 guard on three_kids, a row with ch2 == 2 and ch3 != 2
# (including a missing ch3) would be counted as 1 + 3 = 4 kids.
res_rel['one_kid'] = np.where(res_rel.eitc_ch_elg_rel_test_ch2_ind == 2, 1, 0)
res_rel['two_kids'] = np.where((res_rel.eitc_ch_elg_rel_test_ch3_ind == 2) & (res_rel.eitc_ch_elg_rel_test_ch2_ind != 2), 2, 0)
res_rel['three_kids'] = np.where((res_rel.eitc_ch_elg_rel_test_ch3_ind != 2) & (res_rel.eitc_ch_elg_rel_test_ch2_ind != 2), 3, 0)
res_rel['num_kids'] = res_rel['one_kid'] + res_rel['two_kids'] + res_rel['three_kids']

# Relationship-test score.
# rel_pass_<k> is 1 when child k's relationship indicator equals 1, else 0.
for child_no in (1, 2, 3):
    rel_indicator = res_rel['eitc_ch_elg_rel_test_ch{}_ind'.format(child_no)]
    res_rel['rel_pass_{}'.format(child_no)] = np.where(rel_indicator == 1, 1, 0)
res_rel['rel_pass_total'] = res_rel['rel_pass_1'] + res_rel['rel_pass_2'] + res_rel['rel_pass_3']

# Letter code: "N" = no child passed, "P" = some but fewer than num_kids passed,
# "Y" = exactly num_kids passed. Each flag column holds its letter or "", and
# the score is their concatenation.
rel_passed = res_rel['rel_pass_total']
rel_kids = res_rel['num_kids']
res_rel['no_pass'] = np.where(rel_passed == 0, "N", "")
res_rel['partial_pass'] = np.where((rel_passed > 0) & (rel_passed < rel_kids), "P", "")
res_rel['all_pass'] = np.where(rel_passed == rel_kids, "Y", "")
res_rel['rel_test_score'] = res_rel['no_pass'] + res_rel['partial_pass'] + res_rel['all_pass']

# Residency-test score.
# res_pass_<k> is 1 when child k's residency indicator equals 1, else 0.
for child_no in (1, 2, 3):
    res_indicator = res_rel['eitc_restest_met_ch{}_ind'.format(child_no)]
    res_rel['res_pass_{}'.format(child_no)] = np.where(res_indicator == 1, 1, 0)
res_rel['res_pass_total'] = res_rel['res_pass_1'] + res_rel['res_pass_2'] + res_rel['res_pass_3']

# Letter code: "N" = no child passed, "P" = some but fewer than num_kids passed,
# "Y" = exactly num_kids passed. Each flag column holds its letter or "", and
# the score is their concatenation.
res_passed = res_rel['res_pass_total']
res_kids = res_rel['num_kids']
res_rel['no_pass_res'] = np.where(res_passed == 0, "N", "")
res_rel['partial_pass_res'] = np.where((res_passed > 0) & (res_passed < res_kids), "P", "")
res_rel['all_pass_res'] = np.where(res_passed == res_kids, "Y", "")
res_rel['res_test_score'] = res_rel['no_pass_res'] + res_rel['partial_pass_res'] + res_rel['all_pass_res']

# Overall risk score: residency letter followed by relationship letter
# (e.g. 'NP' = residency "N", relationship "P").
res_rel['overall_risk_score'] = res_rel['res_test_score'] + res_rel['rel_test_score']
# High-risk flag is 1 for the combined codes 'NN' and 'NP', 0 otherwise.
high_risk_codes = ['NN', 'NP']
is_high_risk = res_rel['overall_risk_score'].isin(high_risk_codes)
res_rel['high_risk_ind_research_audits'] = np.where(is_high_risk, 1, 0)

# Merge the scores onto the research_audits-EITC population: keep only rows
# with activity code 270 or 271, then left-join res_rel so that every
# population row is retained whether or not it has a score.
full_research_audits_df = pd.read_csv('/REDACTED/fairness/code/rf/data/clean_rf_data_plus_dep_database.csv')
kept_activity_codes = [270, 271]
in_population = full_research_audits_df['activity_code'].isin(kept_activity_codes)
full_research_audits_df = full_research_audits_df[in_population]
full_research_audits_df_m = pd.merge(
    full_research_audits_df,
    res_rel,
    how='left',
    on=['study_year', 'taxpayer_id_new'],
)


# Get dep_database risk scores.
dep_database = pd.read_csv('/REDACTED/research_audits_eitc_taxpayer_id_list_plus_res_rel.csv')

# (author's note) For only one of three people does the de-duplication make a
# difference in terms of the high-risk/low-risk indicator (YY vs. NN).
# De-duplicating on irs_dep_risk_score doesn't work based on our original dataset.
# Keep the first row per (taxpayer_id, tax_yr), then align the key names with
# the research_audits frame.
dep_database = (
    dep_database
    .drop_duplicates(subset=['taxpayer_id', 'tax_yr'])
    .rename(columns={'taxpayer_id': 'taxpayer_id_new', 'tax_yr': 'study_year'})
)

# High-risk flag is 1 when res_relation is 'NN' or 'NP', 0 otherwise.
dep_is_high_risk = dep_database['res_relation'].isin(['NN', 'NP'])
dep_database['high_risk_ind_dep_database'] = np.where(dep_is_high_risk, 1, 0)

# Merge dep_database and research_audits risk classifications, and assign
# 'not high risk (dep_database)' (0) to taxpayers not present in dep_database.
merged = full_research_audits_df_m.merge(dep_database, on = ['taxpayer_id_new', 'study_year'], how='left')
# Keep only rows that received a research_audits classification.
scores_research_audits = merged[merged.high_risk_ind_research_audits.notnull()]
# keep for now until we figure out 2014 data
# .copy() makes this an independent frame so the column assignment below is
# unambiguous (no write-to-a-slice warning).
scores_research_audits = scores_research_audits[scores_research_audits.study_year != 2014].copy()
# Assign the filled column back explicitly. Calling fillna(inplace=True) on a
# column pulled out via attribute access is chained assignment, which does not
# update the frame under pandas Copy-on-Write; the unmatched rows would keep
# NaN and drop out of every cell of the 2x2 table.
scores_research_audits['high_risk_ind_dep_database'] = scores_research_audits['high_risk_ind_dep_database'].fillna(0)

# write out results
# Builds, separately for the Black and non-Black groups, a weighted 2x2 table
# of dep_database high-risk flag (rows) by research_audits high-risk flag
# (columns), plus false-positive / false-negative rates, and writes them to a
# text report.


def _weighted_total(frame, prob_col):
    """Return the sum over rows of frame[prob_col] * frame['base_weight']."""
    return (frame[prob_col] * frame['base_weight']).sum()


def _write_panel(out, title, shares, fpr, fnr):
    """Write one report panel (2x2 table shares, then FPR and FNR) to `out`.

    `shares` is ordered top-left, top-right, bottom-left, bottom-right.
    All numbers are rounded to 2 decimals.
    """
    print(title, file=out)
    print("2x2 Table", file=out)
    labels = ("Top left", "Top right", "Bottom left", "Bottom right")
    for label, share in zip(labels, shares):
        print(label + ": " + str(round(share, 2)), file=out)
    # The rates must be printed explicitly: a bare `round(x, 2)` expression
    # produces no output when run as a script.
    print("\nFPR", file=out)
    print(round(fpr, 2), file=out)
    print("FNR", file=out)
    print(round(fnr, 2), file=out)


# group totals used as denominators for the table shares
black_weight = _weighted_total(scores_research_audits, 'predicted_prob_black')
nonblack_weight = _weighted_total(scores_research_audits, 'predicted_prob_nonblack')

dep_flag = scores_research_audits.high_risk_ind_dep_database
research_flag = scores_research_audits.high_risk_ind_research_audits
cell0 = scores_research_audits[(dep_flag == 0) & (research_flag == 0)] # top-left of 2x2 table
cell1 = scores_research_audits[(dep_flag == 0) & (research_flag == 1)] # top-right of 2x2 table
cell2 = scores_research_audits[(dep_flag == 1) & (research_flag == 0)] # bottom-left of 2x2 table
cell3 = scores_research_audits[(dep_flag == 1) & (research_flag == 1)] # bottom-right of 2x2 table

# incorporate research_audits weights
cell0_black_weight = _weighted_total(cell0, 'predicted_prob_black')
cell0_nonblack_weight = _weighted_total(cell0, 'predicted_prob_nonblack')

cell1_black_weight = _weighted_total(cell1, 'predicted_prob_black')
cell1_nonblack_weight = _weighted_total(cell1, 'predicted_prob_nonblack')

cell2_black_weight = _weighted_total(cell2, 'predicted_prob_black')
cell2_nonblack_weight = _weighted_total(cell2, 'predicted_prob_nonblack')

cell3_black_weight = _weighted_total(cell3, 'predicted_prob_black')
cell3_nonblack_weight = _weighted_total(cell3, 'predicted_prob_nonblack')

# Panel A shares: each cell's weight as a fraction of the group total
cell0_black = cell0_black_weight / black_weight
cell1_black = cell1_black_weight / black_weight
cell2_black = cell2_black_weight / black_weight
cell3_black = cell3_black_weight / black_weight

# false-positive and false-negative rates, with the research_audits flag as
# the reference: FPR = dep-flagged share among research_audits == 0,
# FNR = dep-unflagged share among research_audits == 1
black_fpr = cell2_black / (cell0_black + cell2_black)
black_fnr = cell1_black / (cell1_black + cell3_black)

# Panel B shares
cell0_nonblack = cell0_nonblack_weight / nonblack_weight
cell1_nonblack = cell1_nonblack_weight / nonblack_weight
cell2_nonblack = cell2_nonblack_weight / nonblack_weight
cell3_nonblack = cell3_nonblack_weight / nonblack_weight

nonblack_fpr = cell2_nonblack / (cell0_nonblack + cell2_nonblack)
nonblack_fnr = cell1_nonblack / (cell1_nonblack + cell3_nonblack)

# Write straight to the report file instead of rebinding sys.stdout: the
# `with` block guarantees the file is flushed and closed even if a write
# fails, and whatever stream sys.stdout pointed to beforehand is untouched.
with open("/REDACTED/research_audits_dep_database_high_risk_classification.txt", "w") as report:
    _write_panel(
        report,
        "Panel A: Black",
        (cell0_black, cell1_black, cell2_black, cell3_black),
        black_fpr,
        black_fnr,
    )
    _write_panel(
        report,
        "\n\nPanel B: Non-Black",
        (cell0_nonblack, cell1_nonblack, cell2_nonblack, cell3_nonblack),
        nonblack_fpr,
        nonblack_fnr,
    )